# A utility to split large images and their VOC format annotations into smaller tiles while maintaining proper object annotations.

import os
from PIL import Image
import xml.etree.ElementTree as ET
from xml.dom.minidom import parseString


def parse_voc_xml(xml_file):
    """Read a Pascal VOC annotation and extract its bounding boxes.

    Args:
        xml_file: Path to the annotation XML, or an open file object; it is
            passed straight to ``xml.etree.ElementTree.parse``.

    Returns:
        A ``(filename, objects)`` tuple. ``filename`` is the text of the
        root ``<filename>`` element. ``objects`` holds one dict per
        ``<object>`` element, in document order, with the keys ``class``,
        ``xmin``, ``ymin``, ``xmax`` and ``ymax`` (coordinates as ints).

    Raises:
        AttributeError: If ``filename``, ``name``, ``bndbox`` or one of the
            coordinate elements is missing.
        ValueError: If a coordinate is not numeric.
        TypeError: If a coordinate element is present but empty.
    """
    root = ET.parse(xml_file).getroot()
    filename = root.find('filename').text

    def read_coord(bbox, tag):
        text = bbox.find(tag).text
        try:
            return int(text)
        except ValueError:
            # Some annotation tools write coordinates as floats ("12.0",
            # "34.5"); int() rejects those, so parse as float and truncate.
            return int(float(text))

    objects = []
    for obj in root.findall('object'):
        bbox = obj.find('bndbox')
        objects.append({
            'class': obj.find('name').text,
            'xmin': read_coord(bbox, 'xmin'),
            'ymin': read_coord(bbox, 'ymin'),
            'xmax': read_coord(bbox, 'xmax'),
            'ymax': read_coord(bbox, 'ymax'),
        })

    return filename, objects


def create_voc_xml(filename, objects, width, height, folder='images'):
    """Build a pretty-printed Pascal VOC annotation document.

    Args:
        filename: Image file name written to ``<filename>`` and, joined
            with ``folder``, to ``<path>``.
        objects: Iterable of dicts with the keys ``class``, ``xmin``,
            ``ymin``, ``xmax`` and ``ymax``; each becomes one ``<object>``.
        width: Image width written to ``<size>/<width>``.
        height: Image height written to ``<size>/<height>``.
        folder: Value of the ``<folder>`` element.

    Returns:
        The XML document as a string indented with two spaces. ``depth`` is
        always written as ``3``, ``segmented``/``truncated``/``difficult``
        as ``0``, ``pose`` as ``Unspecified`` and the source database as
        ``Unknown``.
    """
    def add_child(parent, tag, text=None):
        # Append <tag> under parent; leave .text untouched for containers.
        child = ET.SubElement(parent, tag)
        if text is not None:
            child.text = text
        return child

    root = ET.Element('annotation')

    # Header fields, in the order they appear in the output document.
    add_child(root, 'folder', folder)
    add_child(root, 'filename', filename)
    add_child(root, 'path', os.path.join(folder, filename))
    add_child(add_child(root, 'source'), 'database', 'Unknown')

    size_node = add_child(root, 'size')
    for tag, value in (('width', str(width)), ('height', str(height)), ('depth', '3')):
        add_child(size_node, tag, value)

    add_child(root, 'segmented', '0')

    # One <object> entry per bounding box.
    for box in objects:
        object_node = add_child(root, 'object')
        add_child(object_node, 'name', box['class'])
        add_child(object_node, 'pose', 'Unspecified')
        add_child(object_node, 'truncated', '0')
        add_child(object_node, 'difficult', '0')

        bndbox_node = add_child(object_node, 'bndbox')
        for key in ('xmin', 'ymin', 'xmax', 'ymax'):
            add_child(bndbox_node, key, str(box[key]))

    # ElementTree output is unindented; round-trip through minidom to indent.
    serialized = ET.tostring(root, 'utf-8')
    return parseString(serialized).toprettyxml(indent="  ")


def _clip_objects_to_tile(objects, left, top, tile_size, min_overlap):
    """Return copies of *objects* clipped to one tile, in tile-local coordinates."""
    right = left + tile_size
    bottom = top + tile_size

    clipped = []
    for obj in objects:
        xmin, ymin, xmax, ymax = obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax']

        # Zero-size or inverted boxes have no meaningful area; dividing by
        # it would raise ZeroDivisionError or give a nonsensical ratio.
        if xmax <= xmin or ymax <= ymin:
            continue

        # No overlap with this tile at all.
        if xmax <= left or xmin >= right or ymax <= top or ymin >= bottom:
            continue

        inter_xmin = max(xmin, left)
        inter_ymin = max(ymin, top)
        inter_xmax = min(xmax, right)
        inter_ymax = min(ymax, bottom)
        intersection_area = (inter_xmax - inter_xmin) * (inter_ymax - inter_ymin)
        original_area = (xmax - xmin) * (ymax - ymin)

        # Keep the object only when a large enough share of it is inside.
        if intersection_area / original_area > min_overlap:
            obj_copy = obj.copy()
            obj_copy['xmin'] = inter_xmin - left
            obj_copy['ymin'] = inter_ymin - top
            obj_copy['xmax'] = inter_xmax - left
            obj_copy['ymax'] = inter_ymax - top
            clipped.append(obj_copy)

    return clipped


def split_image_and_annotations(xml_dir, images_dir, output_image_dir, output_xml_dir, tile_size=512,
                                min_overlap=0.3):
    """Cut every annotated image into square tiles with matching VOC XML files.

    For each ``.xml`` file in ``xml_dir`` the referenced image is loaded from
    ``images_dir`` and walked in non-overlapping ``tile_size`` x ``tile_size``
    steps. Partial tiles at the right and bottom edges are dropped. A tile is
    written only if at least one object remains after clipping. Tiles are
    saved as ``<image stem>_<x>_<y>.jpg`` in ``output_image_dir``, and the
    annotation is saved under the same stem with ``.xml`` in
    ``output_xml_dir``.

    Args:
        xml_dir: Directory containing the source VOC annotation files.
        images_dir: Directory containing the images named in the annotations.
        output_image_dir: Destination for tile images (created if missing).
        output_xml_dir: Destination for tile annotations (created if missing).
        tile_size: Edge length of each tile in pixels.
        min_overlap: An object is kept in a tile only if the fraction of its
            original area lying inside the tile is strictly greater than
            this value.

    Images that cannot be found are reported on stdout and skipped.
    """
    # Create output directories if they do not exist
    os.makedirs(output_image_dir, exist_ok=True)
    os.makedirs(output_xml_dir, exist_ok=True)

    # Modes Pillow's JPEG writer accepts; anything else (e.g. RGBA, P) makes
    # save() raise OSError, so those tiles are converted to RGB first.
    jpeg_modes = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr')

    for xml_file in os.listdir(xml_dir):
        if not xml_file.endswith('.xml'):
            continue

        xml_path = os.path.join(xml_dir, xml_file)
        filename, objects = parse_voc_xml(xml_path)
        image_path = os.path.join(images_dir, filename)

        try:
            image = Image.open(image_path)
        except FileNotFoundError:
            print(f"Image {filename} not found in {images_dir}. Skipping.")
            continue

        # The context manager closes the underlying file once all tiles of
        # this image have been written.
        with image:
            image_width, image_height = image.size
            stem = os.path.splitext(filename)[0]

            # Stopping at (dimension - tile_size) yields full tiles only;
            # partial tiles at the right/bottom edges are skipped.
            for i in range(0, image_width - tile_size + 1, tile_size):
                for j in range(0, image_height - tile_size + 1, tile_size):
                    tile_objects = _clip_objects_to_tile(objects, i, j, tile_size, min_overlap)
                    if not tile_objects:
                        continue

                    # Crop only after knowing the tile will be written.
                    tile = image.crop((i, j, i + tile_size, j + tile_size))
                    if tile.mode not in jpeg_modes:
                        tile = tile.convert('RGB')

                    tile_filename = f"{stem}_{i}_{j}.jpg"
                    tile_output_path = os.path.join(output_image_dir, tile_filename)
                    tile.save(tile_output_path)

                    xml_content = create_voc_xml(tile_filename, tile_objects, tile_size, tile_size)
                    xml_output_path = os.path.join(output_xml_dir, os.path.splitext(tile_filename)[0] + '.xml')
                    # The generated XML declaration names no encoding, which
                    # XML readers treat as UTF-8, so write it as UTF-8
                    # instead of the platform default.
                    with open(xml_output_path, 'w', encoding='utf-8') as f:
                        f.write(xml_content)


# Placeholder locations; edit these before running the script.
xml_dir = r'D:/***'  # Replace with your XML file directory path
images_dir = r'D:/***'  # Replace with your image directory path
output_image_dir = r'D:/***'  # Output directory for images
output_xml_dir = r'D:/***'  # Output directory for XML files

# Run the conversion only when executed as a script, so that importing this
# module to reuse its functions does not start processing files.
if __name__ == '__main__':
    split_image_and_annotations(xml_dir, images_dir, output_image_dir, output_xml_dir)
